import pandas as pd
from transformers import pipeline
from datasets import load_dataset
import numpy as np
import time
import torch
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Number of documents drawn for the benchmark
ndocs = 5000

# Load the 'mlburnham/Pol_NLI' test split and keep a fixed, seeded sample of
# its premises as the benchmark input
ds = load_dataset('mlburnham/Pol_NLI')
test = ds['test'].to_pandas()
test = test.sample(n=ndocs, random_state=1)
docs = test['premise'].tolist()

# Models benchmarked by this script (names as stored in the 'Model' column)
models = [
    'Political_DEBATE_ModernBERT_base_v1.0',
    'Political_DEBATE_ModernBERT_large_v1.0',
]

# Existing benchmark table, with any earlier rows for the models above removed
# so that re-running the script replaces them instead of duplicating them
timings_df = pd.read_csv('./data/cuda_timings.csv')
timings_df = timings_df.loc[~timings_df['Model'].isin(models)]

# Collects one result dict per model benchmarked in this run
timings = []

def time_it(pipe, docs, n=1, model_name=None):
    """
    Benchmark the time taken by a model pipeline.

    Each run sends every document in ``docs`` through ``pipe`` with the single
    candidate label 'This text is about politics.' and records the elapsed
    wall-clock time. Per-run and summary statistics are printed as they are
    computed.

    Args:
        pipe (callable): The model pipeline to benchmark. It must expose
            ``device.type``; unless ``model_name`` is given it should also
            expose ``model.name_or_path``.
        docs (list): A list of documents that will be passed to the pipe.
        n (int): Number of times to run the benchmark. Must be at least 1.
            Defaults to 1.
        model_name (str, optional): Name recorded in the 'Model' field. When
            omitted it is taken from ``pipe.model.name_or_path`` and, failing
            that, from a module-level ``model`` variable (legacy behaviour).
            Only the part after the last '/' is stored.

    Returns:
        dict: A dictionary with the keys 'Model', 'Hardware', 'Time' (mean
              seconds per run), 'Time_SE', 'DPS' (documents per second,
              ``len(docs)`` divided by the mean time) and 'DPS_SE'. Both
              standard errors are None when n == 1.

    Raises:
        ValueError: If ``n`` is smaller than 1, or if no model name can be
            determined.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    # Resolve the label before doing any expensive work so a missing name
    # fails immediately rather than after the benchmark has finished.
    if model_name is None:
        model_name = getattr(getattr(pipe, 'model', None), 'name_or_path', None)
    if not model_name:
        # Backward compatibility: earlier versions read the script-level
        # ``model`` variable.
        model_name = globals().get('model')
    if not model_name:
        raise ValueError("Could not determine the model name; pass model_name explicitly")

    # Throughput is based on the documents actually passed in, not on a
    # module-level constant that may be out of sync with ``docs``.
    n_docs = len(docs)
    times = []

    for i in range(n):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        pipe(docs, 'This text is about politics.', hypothesis_template='{}')
        elapsed_time = time.perf_counter() - start_time
        times.append(elapsed_time)

        print(f"Run {i + 1}/{n} - Elapsed time: {elapsed_time:.2f} seconds")

    # Calculate the average time and DPS
    times = np.asarray(times)
    avg_time = times.mean()
    dps = n_docs / avg_time

    # Standard errors need at least two runs (sample standard deviation)
    if n > 1:
        time_se = np.std(times, ddof=1) / np.sqrt(n)
        dps_se = np.std(n_docs / times, ddof=1) / np.sqrt(n)
    else:
        time_se = None
        dps_se = None

    print(f"Average elapsed time: {avg_time:.2f} seconds")
    print(f"Average DPS: {dps}")

    if n > 1:
        print(f"Standard error (Time): {time_se:.4f} seconds")
        print(f"Standard error (DPS): {dps_se:.4f}")

    # Release cached accelerator memory; other device types have no cache to clear
    device_type = pipe.device.type
    if device_type == 'mps':
        torch.mps.empty_cache()
    elif device_type == 'cuda':
        torch.cuda.empty_cache()

    return {
        'Model': str(model_name).rstrip('/').split('/')[-1],
        'Hardware': device_type,
        'Time': avg_time,
        'Time_SE': time_se,
        'DPS': dps,
        'DPS_SE': dps_se,
    }

##########################################
## ModernBERT Large, then ModernBERT Base
##########################################
# Each entry is (model id, pipeline batch size); the order matches the order
# in which results are appended to `timings`.
for model, batch_size in (
    ("mlburnham/Political_DEBATE_ModernBERT_large_v1.0", 512),
    ("mlburnham/Political_DEBATE_ModernBERT_base_v1.0", 64),
):
    pipe = pipeline(
        "zero-shot-classification",
        model=model,
        device='cuda',
        batch_size=batch_size,
        torch_dtype=torch.bfloat16,
    )

    # Untimed first pass ("once to compile"); its result is discarded
    time_it(pipe, docs, n=1)

    # Recorded benchmark: ten timed runs
    res = time_it(pipe, docs, n=10)
    timings.append(res)

#########
## Export
#########
# Append the new rows to the retained benchmarks and overwrite the CSV
mb_timings = pd.DataFrame(timings)
timings_df = pd.concat([timings_df, mb_timings])
timings_df.to_csv('./data/cuda_timings.csv', index=False)